135 changes: 38 additions & 97 deletions onnxruntime/core/session/inference_session.cc
@@ -133,140 +133,81 @@ class InferenceSession::Impl {
return Status::OK();
}

template <typename T>
common::Status Load(const T& model_uri) {
common::Status Load(std::function<common::Status(std::shared_ptr<Model>&)> loader, const std::string& event_name) {
Status status = Status::OK();
auto tp = session_profiler_.StartTime();
try {
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
if (is_model_loaded_) { // already loaded
LOGS(*session_logger_, ERROR) << "This session already contains a loaded model.";
return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model.");
return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED,
"This session already contains a loaded model.");
}

std::shared_ptr<onnxruntime::Model> p_tmp_model;
ORT_RETURN_IF_ERROR(onnxruntime::Model::Load(model_uri, p_tmp_model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr));
status = loader(p_tmp_model);
ORT_RETURN_IF_ERROR(status);

model_ = p_tmp_model;

ORT_RETURN_IF_ERROR(DoPostLoadProcessing(*model_.get()));
status = DoPostLoadProcessing(*model_);
ORT_RETURN_IF_ERROR(status);

// all steps complete, mark the model as loaded.
is_model_loaded_ = true;
} catch (const std::exception& ex) {
return Status(common::ONNXRUNTIME, common::FAIL, "Exception during loading: " + std::string(ex.what()));
status = Status(common::ONNXRUNTIME, common::FAIL, "Exception during loading: " + std::string(ex.what()));
} catch (...) {
LOGS(*session_logger_, ERROR) << "Unknown exception in Load()";
return Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
status = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
}

if (session_profiler_.FEnabled()) {
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_uri", tp);
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, event_name, tp);
}
return common::Status::OK();

return status;
}
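
This consolidated Load is the heart of the change: it owns the lock, the already-loaded check, the exception-to-Status translation, and the profiler event, while each public overload supplies only a source-specific loader callback plus an event name. A minimal standalone sketch of the pattern follows; Session, Model, and Status here are simplified stand-ins for illustration, not the actual onnxruntime types.

#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>

struct Model {
  std::string source;
};

struct Status {
  Status() = default;
  Status(bool ok_in, std::string msg) : ok(ok_in), message(std::move(msg)) {}
  bool ok = true;
  std::string message;
  static Status OK() { return Status(); }
};

class Session {
 public:
  // Shared implementation: one place for locking, the already-loaded
  // check, exception-to-Status translation, and profiling.
  Status Load(const std::function<Status(std::shared_ptr<Model>&)>& loader,
              const std::string& event_name) {
    Status status = Status::OK();
    std::lock_guard<std::mutex> lock(mutex_);
    if (model_) {
      return Status(false, "This session already contains a loaded model.");
    }
    try {
      std::shared_ptr<Model> tmp_model;
      status = loader(tmp_model);
      if (!status.ok) {
        return status;
      }
      model_ = tmp_model;  // commit only after the loader succeeded
    } catch (const std::exception& ex) {
      status = Status(false, std::string("Exception during loading: ") + ex.what());
    }
    std::cout << "recorded profiling event: " << event_name << "\n";
    return status;
  }

  // Each overload shrinks to a lambda that knows only how to obtain a
  // Model from its own input kind.
  Status Load(const std::string& uri) {
    auto loader = [&uri](std::shared_ptr<Model>& model) {
      model = std::make_shared<Model>(Model{"loaded from " + uri});
      return Status::OK();
    };
    return Load(loader, "model_loading_uri");
  }

 private:
  std::mutex mutex_;
  std::shared_ptr<Model> model_;
};

int main() {
  Session session;
  const Status first = session.Load("model.onnx");   // succeeds
  const Status second = session.Load("model.onnx");  // rejected: already loaded
  std::cout << first.ok << " " << second.ok << " [" << second.message << "]\n";
  return 0;
}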

common::Status Load(const ModelProto& model_proto) {
auto tp = session_profiler_.StartTime();
try {
LOGS(*session_logger_, INFO) << "Loading model using model_proto";
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
if (is_model_loaded_) { // already loaded
LOGS(*session_logger_, ERROR) << "This session already contains a loaded model.";
return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model.");
}

std::shared_ptr<onnxruntime::Model> p_tmp_model;
ORT_RETURN_IF_ERROR(onnxruntime::Model::Load(model_proto, p_tmp_model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr));
model_ = p_tmp_model;
template <typename T>
common::Status Load(const T& model_uri) {
auto loader = [this, &model_uri](std::shared_ptr<onnxruntime::Model>& model) {
return onnxruntime::Model::Load(model_uri, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
};

ORT_RETURN_IF_ERROR(DoPostLoadProcessing(*model_.get()));
return Load(loader, "model_loading_uri");
}
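
The URI overload stays templated on the path type, presumably so that both narrow and wide string paths (std::string and std::wstring) forward to the matching onnxruntime::Model::Load overload; the lambda itself only captures and forwards model_uri.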

// all steps complete, mark the model as loaded.
is_model_loaded_ = true;
common::Status Load(const ModelProto& model_proto) {
auto loader = [this, &model_proto](std::shared_ptr<onnxruntime::Model>& model) {
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
};

LOGS(*session_logger_, INFO) << "Model successfully loaded.";
} catch (const std::exception& ex) {
return Status(common::ONNXRUNTIME, common::FAIL, "Exception during loading: " + std::string(ex.what()));
} catch (...) {
LOGS(*session_logger_, ERROR) << "Unknown exception in Load()";
return Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
}
if (session_profiler_.FEnabled()) {
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_proto", tp);
}
return Status::OK();
return Load(loader, "model_loading_proto");
}

common::Status Load(std::unique_ptr<ModelProto> p_model_proto) {
auto tp = session_profiler_.StartTime();
try {
LOGS(*session_logger_, INFO) << "Loading model using model_proto";
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
if (is_model_loaded_) { // already loaded
LOGS(*session_logger_, ERROR) << "This session already contains a loaded model.";
return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model.");
}
auto loader = [this, &p_model_proto](std::shared_ptr<onnxruntime::Model>& model) {
return onnxruntime::Model::Load(std::move(p_model_proto), model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr);
};

std::shared_ptr<onnxruntime::Model> p_tmp_model;
ORT_RETURN_IF_ERROR(onnxruntime::Model::Load(std::move(p_model_proto), p_tmp_model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr));
model_ = p_tmp_model;

ORT_RETURN_IF_ERROR(DoPostLoadProcessing(*model_.get()));

// all steps complete, mark the model as loaded.
is_model_loaded_ = true;

LOGS(*session_logger_, INFO) << "Model successfully loaded.";
} catch (const std::exception& ex) {
return Status(common::ONNXRUNTIME, common::FAIL, "Exception during loading: " + std::string(ex.what()));
} catch (...) {
LOGS(*session_logger_, ERROR) << "Unknown exception in Load()";
return Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
}
if (session_profiler_.FEnabled()) {
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_proto", tp);
}
return Status::OK();
return Load(loader, "model_loading_proto");
}
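
One subtlety in the overload above: the lambda captures p_model_proto by reference and moves from it inside its body. That is safe only because the shared Load invokes the loader synchronously, while the overload's stack frame (and therefore the captured unique_ptr) is still alive; if the loader were ever stored and run later, the reference would dangle.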

common::Status Load(std::istream& model_istream) {
auto tp = session_profiler_.StartTime();
try {
LOGS(*session_logger_, INFO) << "Loading model using istream";
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
if (is_model_loaded_) { // already loaded
LOGS(*session_logger_, ERROR) << "This session already contains a loaded model.";
return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model.");
}

auto loader = [this, &model_istream](std::shared_ptr<onnxruntime::Model>& model) {
ModelProto model_proto;
const bool result = model_proto.ParseFromIstream(&model_istream);
if (!result) {
return Status(common::ONNXRUNTIME, common::INVALID_PROTOBUF, "Failed to load model because protobuf parsing failed.");
return Status(common::ONNXRUNTIME, common::INVALID_PROTOBUF,
"Failed to load model because protobuf parsing failed.");
}

std::shared_ptr<onnxruntime::Model> p_tmp_model;
ORT_RETURN_IF_ERROR(onnxruntime::Model::Load(model_proto, p_tmp_model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr));
model_ = p_tmp_model;

ORT_RETURN_IF_ERROR(DoPostLoadProcessing(*model_.get()));
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
};

// all steps complete, mark the model as loaded.
is_model_loaded_ = true;

LOGS(*session_logger_, INFO) << "Model successfully loaded.";
} catch (const std::exception& ex) {
return Status(common::ONNXRUNTIME, common::FAIL, "Exception during loading: " + std::string(ex.what()));
} catch (...) {
LOGS(*session_logger_, ERROR) << "Unknown exception in Load()";
return Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
}
if (session_profiler_.FEnabled()) {
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_istream", tp);
}
return common::Status::OK();
return Load(loader, "model_loading_istream");
}
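
For completeness, a hedged usage sketch of the istream path. It assumes the public InferenceSession forwards to the Impl::Load(std::istream&) overload shown above (the tests below only exercise the URI path); the helper name and buffer contents are illustrative. Any std::istream works, since protobuf parsing happens inside the loader lambda, behind the same locked, profiled path as every other overload.

#include <sstream>
#include <string>

#include "core/session/inference_session.h"

// Load a model from an in-memory buffer of serialized ModelProto bytes.
onnxruntime::common::Status LoadFromBytes(onnxruntime::InferenceSession& session,
                                          const std::string& serialized_model) {
  std::istringstream model_stream(serialized_model);
  return session.Load(model_stream);
}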

static common::Status TransformGraph(onnxruntime::Graph& graph,
60 changes: 39 additions & 21 deletions onnxruntime/test/framework/inference_session_test.cc
@@ -34,7 +34,7 @@

using namespace std;
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::logging;
using namespace onnxruntime::logging;

namespace onnxruntime {
class FuseAdd : public OpKernel {
@@ -77,7 +77,8 @@ class FuseExecutionProvider : public IExecutionProvider {
public:
explicit FuseExecutionProvider() {
DeviceAllocatorRegistrationInfo device_info({OrtMemTypeDefault,
[](int) { return std::make_unique<CPUAllocator>(); }, std::numeric_limits<size_t>::max()});
[](int) { return std::make_unique<CPUAllocator>(); },
std::numeric_limits<size_t>::max()});
InsertAllocator(std::shared_ptr<IArenaAllocator>(
std::make_unique<DummyArena>(device_info.factory(0))));
}
@@ -91,7 +92,7 @@ class FuseExecutionProvider : public IExecutionProvider {
for (auto& node : graph.Nodes()) {
sub_graph->nodes.push_back(node.Index());
}
auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>();
auto meta_def = std::make_unique<IndexedSubGraph::MetaDef>();
meta_def->name = "FuseAdd";
meta_def->domain = "FuseTest";
meta_def->inputs = {"X", "Y", "Z"};
@@ -103,7 +104,7 @@
return result;
}

std::shared_ptr<::onnxruntime::KernelRegistry> GetKernelRegistry() const override {
std::shared_ptr<KernelRegistry> GetKernelRegistry() const override {
static std::shared_ptr<KernelRegistry> kernel_registry = GetFusedKernelRegistry();
return kernel_registry;
}
@@ -134,7 +135,8 @@ static void CreateMatMulModel(std::unique_ptr<onnxruntime::Model>& p_model, Prov
std::unordered_map<std::string, int> domain_to_version;
domain_to_version[onnxruntime::kOnnxDomain] = 7;
// Generate the input & output def lists
p_model = std::make_unique<onnxruntime::Model>("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version);
p_model = std::make_unique<onnxruntime::Model>("test", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version);
onnxruntime::Graph& graph = p_model->MainGraph();

TypeProto tensor_float;
@@ -171,7 +173,8 @@ void VerifyOutputs(const std::vector<MLValue>& fetches,
auto& rtensor = fetches.front().Get<Tensor>();
TensorShape expected_shape(expected_dims);
ASSERT_EQ(expected_shape, rtensor.Shape());
const std::vector<float> found(rtensor.template Data<float>(), rtensor.template Data<float>() + expected_values.size());
const std::vector<float> found(rtensor.template Data<float>(),
rtensor.template Data<float>() + expected_values.size());
ASSERT_EQ(expected_values, found);
}

@@ -182,7 +185,8 @@ void RunModel(InferenceSession& session_object,
std::vector<int64_t> dims_mul_x = {3, 2};
std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
MLValue ml_value;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value);
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
&ml_value);
NameMLValMap feeds;
feeds.insert(std::make_pair("X", ml_value));

@@ -194,7 +198,8 @@ if (is_preallocate_output_vec) {
if (is_preallocate_output_vec) {
fetches.resize(output_names.size());
for (auto& elem : fetches) {
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &elem);
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
&elem);
}
}

@@ -237,7 +242,8 @@ void RunModelWithBindingMatMul(InferenceSession& session_object,

MLValue input_ml_value_B;
std::vector<int64_t> dims_mul_x_B = {4, 3};
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x_B, values_mul_x, &input_ml_value_B);
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x_B, values_mul_x,
&input_ml_value_B);

io_binding->BindInput("A", input_ml_value_A);
io_binding->BindInput("B", input_ml_value_B);
@@ -247,10 +253,12 @@ MLValue output_ml_value;
MLValue output_ml_value;
if (is_preallocate_output_vec) {
if (allocation_provider == kCpuExecutionProvider) {
AllocateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), expected_output_dims, &output_ml_value);
AllocateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), expected_output_dims,
&output_ml_value);
} else if (allocation_provider == kCudaExecutionProvider) {
#ifdef USE_CUDA
AllocateMLValue<float>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), expected_output_dims, &output_ml_value);
AllocateMLValue<float>(TestCudaExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), expected_output_dims,
&output_ml_value);
#endif
} else {
ORT_THROW("Unsupported provider");
@@ -430,7 +438,8 @@ TEST(InferenceSessionTests, CheckRunLogger) {
auto capturing_sink = new CapturingSink();

auto logging_manager = std::make_unique<logging::LoggingManager>(
std::unique_ptr<ISink>(capturing_sink), logging::Severity::kVERBOSE, false, LoggingManager::InstanceType::Temporal);
std::unique_ptr<ISink>(capturing_sink), logging::Severity::kVERBOSE, false,
LoggingManager::InstanceType::Temporal);

InferenceSession session_object{so, logging_manager.get()};
ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK());
@@ -446,7 +455,9 @@ TEST(InferenceSessionTests, CheckRunLogger) {
std::copy(msgs.begin(), msgs.end(), std::ostream_iterator<std::string>(std::cout, "\n"));
bool have_log_entry_with_run_tag =
(std::find_if(msgs.begin(), msgs.end(),
[&run_options](std::string msg) { return msg.find(run_options.run_tag) != string::npos; }) != msgs.end());
[&run_options](std::string msg) {
return msg.find(run_options.run_tag) != string::npos;
}) != msgs.end());

ASSERT_TRUE(have_log_entry_with_run_tag);
#endif
@@ -750,7 +761,8 @@ TEST(InferenceSessionTests, InvalidInputTypeOfTensorElement) {
std::vector<int64_t> dims_mul_x = {3, 2};
std::vector<int64_t> values_mul_x = {1, 2, 3, 4, 5, 6};
MLValue ml_value;
CreateMLValue<int64_t>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value);
CreateMLValue<int64_t>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
&ml_value);
NameMLValMap feeds;
feeds.insert(std::make_pair("X", ml_value));

@@ -991,11 +1003,14 @@ TEST(ExecutionProviderTest, FunctionTest) {
std::vector<int64_t> dims_mul_x = {3, 2};
std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
MLValue ml_value_x;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_x);
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
&ml_value_x);
MLValue ml_value_y;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_y);
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
&ml_value_y);
MLValue ml_value_z;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_z);
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
&ml_value_z);
NameMLValMap feeds;
feeds.insert(std::make_pair("X", ml_value_x));
feeds.insert(std::make_pair("Y", ml_value_y));
@@ -1016,7 +1031,7 @@ TEST(ExecutionProviderTest, FunctionTest) {
VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);

InferenceSession session_object_2{so};
session_object_2.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>());
session_object_2.RegisterExecutionProvider(std::make_unique<FuseExecutionProvider>());
status = session_object_2.Load(model_file_name);
ASSERT_TRUE(status.IsOK());
status = session_object_2.Initialize();
@@ -1094,11 +1109,14 @@ TEST(ExecutionProviderTest, FunctionInlineTest) {
std::vector<int64_t> dims_mul_x = {2, 2};
std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f};
MLValue ml_value_x;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_x);
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
&ml_value_x);
MLValue ml_value_y;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_y);
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
&ml_value_y);
MLValue ml_value_z;
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x, &ml_value_z);
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_mul_x, values_mul_x,
&ml_value_z);
NameMLValMap feeds;
feeds.insert(std::make_pair("X", ml_value_x));
feeds.insert(std::make_pair("Y", ml_value_y));